package baidu;

import java.util.Scanner;

/**
 * Reads two integers {@code N} and {@code M} from standard input and prints a
 * single count derived from them, reduced modulo 1_000_000_007.
 *
 * <p>NOTE(review): the problem statement is not part of this file. The
 * recurrence below is the original author's formula, kept as written and only
 * evaluated with modular arithmetic; confirm it against the statement.
 */
public class Main {
    /** Prime modulus applied to every printed result (1e9 + 7). */
    private static final long MOD = 1_000_000_007L;

    public static void main(String[] args) {
        fisrt();
    }

    /**
     * Reads {@code N} (int) then {@code M} (long) from {@code System.in} and
     * prints exactly one line with the result.
     */
    private static void fisrt() {
        // Deliberately not closed: closing the Scanner would close System.in.
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long m = sc.nextLong();
        System.out.println(count(n, m));
    }

    /**
     * Evaluates the counting formula for the given inputs.
     *
     * @param n first input value; 1, 2 and 3 are handled as closed-form cases
     * @param m second input value; reduced modulo {@link #MOD} before use
     * @return the result in the range {@code [0, MOD)}
     */
    private static long count(int n, long m) {
        if (n == 1) {
            return 1;
        }

        // Reduce every operand first so that any product of two of them stays
        // below 2^63 and cannot overflow.
        long mRed = Math.floorMod(m, MOD);
        long q = Math.floorMod(mRed - 1, MOD);       // (M - 1) mod p
        long mMinus2 = Math.floorMod(mRed - 2, MOD); // (M - 2) mod p, also (Q - 1)

        if (n == 2) {
            return q;
        }
        if (n == 3) {
            // Return here: the general formula below is only meant for n >= 4.
            return q * mMinus2 % MOD;
        }

        long res = q;
        for (int i = 2; i < n - 2; i++) {
            // Odd steps multiply by M, even steps by M - 1.
            long factor = (i & 1) == 1 ? mRed : q;
            res = res * factor % MOD;
        }

        // Closing step: res*(M-1) + res*X*(M-2), where X is (Q-1) when n-1 is
        // odd and (M-1) otherwise.
        long first = res * q % MOD;
        long secondFactor = ((n - 1) & 1) == 1 ? mMinus2 : q;
        long second = res * secondFactor % MOD * mMinus2 % MOD;
        return (first + second) % MOD;
    }
}
